classdef LogRegCost < matlab.mixin.SetGet
    %LOGREGCOST Logistic-regression cost function with cached intermediates.
    %   Logistic cost over a selected set S of samples:
    %       f(w) = 1/|S| * \sum_{i in S} log(1 + exp(-y_i w^T x_i))
    %   where x_i is the i-th row of the data matrix augmented with a
    %   trailing 1, so that w = [tilde_w; b] carries the bias b as its last
    %   entry.
    %
    %   Typical call sequence (each step reads what an earlier one cached):
    %       obj.setExpterm(w, indexes);  % caches margin and expterm
    %       obj.setSigmoid();            % caches sigmoid (from expterm)
    %       f = obj.func(indexes);       % reads expterm / margin
    %       g = obj.grad(indexes);       % reads sigmoid
    %       obj.setDiag();               % Hessian weights (from sigmoid)
    %
    %   matlab.mixin.SetGet is a handle class, so the set* methods modify
    %   obj in place (they also return obj).
    %
    %   NOTE(review): labels are presumably +1/-1, as the loss formula
    %   suggests; the class does not check this.
    properties
        X            % data matrix, m x (n+1): m samples, n features plus a trailing column of ones (bias)
        XF           % columns of X selected by setXF (free variables)
        num_features % number of columns of X as built by the constructor (n features + 1 bias column)
        num_samples  % number of samples (length of y at construction)
        y            % labels of samples
        margin       % last computed -y_i*w^T x_i for the selected samples (column vector)
        expterm      % exp(margin) = exp(-y_i*w^T x_i) for the selected samples (column vector)
        sigmoid      % expterm./(1+expterm), evaluated in an overflow-safe form
        diag         % sigmoid - sigmoid.^2: diagonal of the Hessian weight matrix, stored as an array
    end
    
    methods
        
        function obj = LogRegCost(X, y)
            %LOGREGCOST Build the cost object from data X (m x n) and labels y.
            %   A column of ones is appended to X so the last entry of w acts
            %   as the bias.
            obj.num_samples = length(y);
            obj.num_features = size(X, 2) + 1;        % one dummy feature for the bias
            obj.X = [X, ones(length(y), 1)];          % dummy feature is all ones
            obj.y = y;
        end
        
        function f = func(obj, indexes) %#ok<INUSD>
            %FUNC Mean logistic loss over the samples cached by setExpterm.
            %   indexes is kept for interface compatibility. The mean is
            %   taken over the terms stored in obj.expterm, i.e. over the
            %   samples selected by the indexes given to setExpterm. This
            %   is correct for numeric, logical and ':' indexing alike.
            terms = log1p(obj.expterm);   % log1p stays accurate when expterm << 1
            % exp(margin) overflows to Inf once margin exceeds ~709.78; in
            % that range log(1 + exp(t)) equals t to double precision.
            big = isinf(terms);
            if any(big) && isequal(size(obj.margin), size(terms))
                terms(big) = obj.margin(big);
            end
            f = sum(terms) / numel(terms);
        end
        
        function g = grad(obj, indexes)
            %GRAD Gradient of the mean logistic loss w.r.t. w (column vector).
            %   Uses obj.sigmoid, which must come from setExpterm/setSigmoid
            %   called with the same indexes.
            %   d/dw log(1+exp(-y w'x)) = -y * sigmoid * x
            yI = obj.y(indexes);
            yI = yI(:);                    % force a column so .* never broadcasts to a matrix
            XI = obj.X(indexes, :);
            g = -(XI' * (obj.sigmoid(:) .* yI)) / numel(yI);
        end
        
        %% set properties
        function obj = setExpterm(obj, w, indexes)
            %SETEXPTERM Cache margin = -y_i*w'x_i and expterm = exp(margin).
            %   Only the samples selected by indexes are used. Both results
            %   are column vectors.
            yI = obj.y(indexes);
            obj.margin = -yI(:) .* (obj.X(indexes, :) * w);
            obj.expterm = exp(obj.margin);
        end
        
        function obj = setDiag(obj)
            %SETDIAG Cache the Hessian weights sigmoid.*(1-sigmoid).
            obj.diag = obj.sigmoid - (obj.sigmoid).^2;
        end
        
        function obj = setXF(obj, F)
            %SETXF Cache the columns of X indexed by F (free variables).
            obj.XF = obj.X(:, F);
        end
        
        function obj = setSigmoid(obj)
            %SETSIGMOID Cache expterm./(1+expterm).
            %   Written as 1./(1+1./expterm): algebraically identical, but it
            %   yields 1 instead of Inf/Inf = NaN when expterm overflows.
            obj.sigmoid = 1 ./ (1 + 1 ./ obj.expterm);
        end
        
        function obj = setX(obj, X)
            %SETX Replace X as given. It does NOT append a bias column or
            %   update num_features/num_samples.
            set(obj, 'X', X);
        end
        
        function obj = setY(obj, y)
            %SETY Replace the label vector.
            set(obj, 'y', y);
        end
        
        function obj = setNumberfeatures(obj, num_features)
            %SETNUMBERFEATURES Overwrite the stored feature count.
            set(obj, 'num_features', num_features);
        end
        
        function obj = setNumbersamples(obj, num_samples)
            %SETNUMBERSAMPLES Overwrite the stored sample count.
            set(obj, 'num_samples', num_samples);
        end
        
    end
    
end

